//
// Created by neo on 25-6-24.
//

#include "Infer.h"

#include "runtime/log/Log.h"
#include "runtime/utils/TimeUtils.h"

bool Infer::Init() {
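  // Initialization order: compute engine -> tokenizer -> config ->
  // safetensors weights -> model. Any failure aborts and returns false.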
  Logger(Logger::DEBUG) << "Infer Init......" << std::endl;
  const uint64_t initTimeStart = TimeUtils::GetCurrentMonoMs();
  ce = std::make_shared<sgl::compute::ComputeEngine>();
  if (!ce->Init()) {
    Logger() << "Failed to initialize compute engine" << std::endl;
    return false;
  }

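  // Load the tokenizer (tokenizer.json) used to turn the prompt into token
  // ids; the relative path assumes the demo's expected working directory.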
  this->tokenizer = std::make_shared<Tokenizer>();
  bool ok = tokenizer->LoadFromFile("../../../examples/inference_demo/models/"
                                    "Qwen3_0_6B/Qwen3-0.6B/tokenizer.json");
  if (!ok) {
    Logger() << "Failed to load tokenizer.json";
    return false;
  }

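  // Load the model configuration (config.json) describing the architecture.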
  this->config = std::make_shared<Config>();
  ok = config->LoadFromFile("../../../examples/inference_demo/models/"
                            "Qwen3_0_6B/Qwen3-0.6B/config.json");
  if (!ok) {
    Logger() << "Failed to load config.json";
    return false;
  }

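  // Load the model weights from model.safetensors; the loader is given the
  // config so it can interpret the stored tensors.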
  this->safeTensor = std::make_shared<SafeTensor>(this->config);
  ok = safeTensor->LoadFromFile("../../../examples/inference_demo/models/"
                                "Qwen3_0_6B/Qwen3-0.6B/model.safetensors");
  if (!ok) {
    Logger() << "Failed to load safetensors";
    return false;
  }

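  // Build the model on top of the compute engine, config, and loaded
  // weights, then initialize it.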
  this->model =
      std::make_shared<Model>(this->ce, this->config, this->safeTensor);
  ok = model->Init();
  if (!ok) {
    Logger() << "Failed to init model";
    return false;
  }
  const uint64_t initTimeEnd = TimeUtils::GetCurrentMonoMs();
  Logger(Logger::DEBUG) << "Infer Init done! time: "
                        << initTimeEnd - initTimeStart << std::endl
                        << std::endl;
  return true;
}

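// Run a single prompt through the pipeline: tokenize, embed each token,
// run the model forward pass, then dump the model's internal state.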
void Infer::Run(const std::string &prompt) const {
  const uint64_t encodeTimeStart = TimeUtils::GetCurrentMonoMs();
  const auto tokenIds = tokenizer->Encode(prompt);
  const uint64_t encodeTimeEnd = TimeUtils::GetCurrentMonoMs();
  Logger() << "Encode time: " << encodeTimeEnd - encodeTimeStart << " ms"
           << std::endl;

  const uint64_t inferTimeStart = TimeUtils::GetCurrentMonoMs();
  std::vector<std::vector<float>> inputs;
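  // Look up the embedding vector for each token id; the per-token vectors
  // become the model input.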
  for (const int tokenId : tokenIds) {
    const std::vector<float> embedding =
        this->safeTensor->EmbeddingToken(tokenId);
    Logger() << "Token: " << tokenId << std::endl;
    Logger() << "Embedding(" << embedding.size() << "): [";
    for (const float e : embedding) {
      std::cout << e << " ";
    }
    std::cout << "]" << std::endl;
    inputs.push_back(embedding);
  }
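  // Forward the embedded sequence through the model and print the raw
  // output vector.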
  const auto output = model->Forward(inputs);
  Logger() << "Output(" << output.size() << "): [";
  for (const float o : output) {
    std::cout << o << " ";
  }
  std::cout << "]" << std::endl;
  const uint64_t inferTimeEnd = TimeUtils::GetCurrentMonoMs();
  Logger() << "Infer time: " << inferTimeEnd - inferTimeStart << std::endl
           << std::endl;

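  // Dump internal model state for inspection/debugging.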
  this->model->Dump();
}
